#pragma once
// adapted from PyTorch ATen
#include "tensorview.h"
#include <type_traits>
namespace tv {
namespace cuda {

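// Integer ceiling division: smallest integer >= a / b (assumes a >= 0, b > 0).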
template <typename T1, typename T2> inline int DivUp(const T1 a, const T2 b) {
  return (a + b - 1) / b;
}

// Use 1024 threads per block, which requires CUDA sm_2x or above.
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of threads per block for N elements, rounded up to a full
// warp (32) and capped at CUDA_NUM_THREADS.
inline int getNumThreads(const int N) {
  if (N > CUDA_NUM_THREADS) {
    return CUDA_NUM_THREADS;
  }
  return DivUp(N, 32) * 32;
}

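// CUDA: number of blocks needed to cover N elements with
// getNumThreads(N) threads per block.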
inline int getBlocks(const int N) {
  TV_ASSERT_RT_ERR(N > 0,
                   "CUDA kernel launch blocks must be positive, but got N=", N);
  return DivUp(N, getNumThreads(N));
}
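
// Example usage (illustrative sketch only; `copy_kernel`, `dst`, `src`, and
// `num_elements` are hypothetical names, not part of this header). A typical
// launch takes the grid size from getBlocks and the block size from
// getNumThreads:
//
//   int num_elements = 4096;
//   copy_kernel<<<tv::cuda::getBlocks(num_elements),
//                 tv::cuda::getNumThreads(num_elements)>>>(dst, src,
//                                                          num_elements);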

} // namespace cuda

} // namespace tv